import os

import matplotlib.pyplot as plt
import numpy as np

# For each of the five merged dataset files, render the first entry of the
# input array next to the first entry of the matching output array and save
# the side-by-side comparison as ./dataset/output{i}.png.
for i in range(5):
    npy_input = np.load(f"./dataset/input/merged_{i}.npy")
    npy_output = np.load(f"./dataset/output/merged_{i}.npy")

    # Take the first element along axis 0 and drop any length-1 axes so that
    # imshow receives a plain image array.
    # NOTE(review): presumably axis 0 indexes samples and each sample carries
    # a singleton channel axis — confirm against the dataset layout.
    panels = (
        ("Input Image", np.squeeze(npy_input[0])),
        ("Output Image", np.squeeze(npy_output[0])),
    )

    # Show the input image and output image side by side.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, (title, image) in zip(axes, panels):
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    fig.savefig(os.path.join("dataset", f"output{i}.png"))

    # pyplot keeps every figure it creates registered until it is explicitly
    # closed; release it here so figures do not pile up across iterations.
    plt.close(fig)
